/**
 * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file flash_attention_score_grad_tiling_s1s2_bn2gs1s2_def.h
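 * \brief Tiling data definitions and tiling-key registrations for the S1S2_BNGS1S2 template of FlashAttentionScoreGrad.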
 */

#pragma once

#include <cstdint>
#include <register/tilingdata_base.h>
#include <tiling/tiling_api.h>

namespace optiling {

/////////////////////////////////////////////////////////////////////////
// S1S2_BNGS1S2
/////////////////////////////////////////////////////////////////////////
// Base (shape / attribute) parameters for the S1S2_BNGS1S2 FlashAttentionScoreGrad tiling template.
// This type is embedded as the s1s2BNGS1S2BaseParams member of FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2.
// NOTE(review): the tiling-data macros presumably serialize fields in declaration order for the
// device side, so do not reorder, insert or remove fields without updating the consumer - confirm
// against register/tilingdata_base.h.
// Packed size from the declared types: 13 x int64_t + 2 x float + 12 x uint32_t = 160 bytes, and
// every int64_t field falls on an 8-byte-aligned offset in that packed layout.
BEGIN_TILING_DATA_DEF(FlashAttentionScoreGradS1S2BNGS1S2BaseParams)
// Problem dimensions; the names suggest batch, KV-head count, group size, the two sequence lengths
// and head dim - TODO confirm against the tiling code that fills them.
TILING_DATA_FIELD_DEF(int64_t, b);
TILING_DATA_FIELD_DEF(int64_t, n2);
TILING_DATA_FIELD_DEF(int64_t, g);
TILING_DATA_FIELD_DEF(int64_t, s1);
TILING_DATA_FIELD_DEF(int64_t, s2);
TILING_DATA_FIELD_DEF(int64_t, d);
// Scalar operator attributes.
TILING_DATA_FIELD_DEF(float, scaleValue);
TILING_DATA_FIELD_DEF(float, keepProb);
TILING_DATA_FIELD_DEF(int64_t, s1Token); // pre_tokens
TILING_DATA_FIELD_DEF(int64_t, s2Token); // next_tokens
// Sparse / attention-mask configuration. NOTE(review): the encodings of the mode values are defined
// by the tiling code that fills them, not here.
TILING_DATA_FIELD_DEF(uint32_t, sparseMode);
TILING_DATA_FIELD_DEF(uint32_t, isSparse);
TILING_DATA_FIELD_DEF(int64_t, attenMaskS2Size);
TILING_DATA_FIELD_DEF(uint32_t, coreNum);
TILING_DATA_FIELD_DEF(uint32_t, attenMaskCompressMode);
TILING_DATA_FIELD_DEF(int64_t, qStartIdx);
TILING_DATA_FIELD_DEF(int64_t, kvStartIdx);
TILING_DATA_FIELD_DEF(int64_t, pseAlibiBaseS1);
TILING_DATA_FIELD_DEF(int64_t, pseAlibiBaseS2);
// Optional-input descriptors for pse and attention mask (presence / dtype / shape type).
// NOTE(review): enum values for these are defined elsewhere - confirm before relying on them.
TILING_DATA_FIELD_DEF(uint32_t, pseType);
TILING_DATA_FIELD_DEF(uint32_t, pseOptional);
TILING_DATA_FIELD_DEF(uint32_t, pseShapeType);
TILING_DATA_FIELD_DEF(uint32_t, pseDtype);
TILING_DATA_FIELD_DEF(uint32_t, attenMaskOptional);
TILING_DATA_FIELD_DEF(uint32_t, attenMaskDtype);
TILING_DATA_FIELD_DEF(uint32_t, attenMaskShapeType);
// Brings the trailing uint32_t run (pseType..pad) to eight fields, keeping the packed size a
// multiple of 8 bytes - presumably its only purpose; do not reuse it without checking the kernel.
TILING_DATA_FIELD_DEF(uint32_t, pad);
END_TILING_DATA_DEF;
// Registers the class under the name FlashAttentionScoreGradS1S2BNGS1S2BaseParamsOp, presumably so it
// can be looked up by name when used as a nested struct - confirm against tilingdata_base.h.
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGradS1S2BNGS1S2BaseParamsOp, FlashAttentionScoreGradS1S2BNGS1S2BaseParams)

// Core-splitting parameters for the S1S2_BNGS1S2 FlashAttentionScoreGrad tiling template
// (outer/inner/tail split sizes along S1 and S2 and related factors).
// This type is embedded as the s1s2BNGS1S2SplitCoreParams member of FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2.
// NOTE(review): field order presumably defines the serialized layout - do not reorder.
// NOTE(review): s1Outer (8 bytes) + five uint32_t fields = 28 bytes, so s2Outer (int64_t) would sit at a
// non-8-aligned offset if fields were packed back-to-back. Confirm that the tiling framework inserts
// alignment padding here and that the kernel-side struct agrees.
BEGIN_TILING_DATA_DEF(FlashAttentionScoreGradS1S2BNGS1S2SplitCoreParams)
// S1-direction split.
TILING_DATA_FIELD_DEF(int64_t, s1Outer);
TILING_DATA_FIELD_DEF(uint32_t, s1CvRatio);
TILING_DATA_FIELD_DEF(uint32_t, s1Inner);
TILING_DATA_FIELD_DEF(uint32_t, s1CvInner);
TILING_DATA_FIELD_DEF(uint32_t, s1Tail);
TILING_DATA_FIELD_DEF(uint32_t, s1CvTail);
// S2-direction split.
TILING_DATA_FIELD_DEF(int64_t, s2Outer);
TILING_DATA_FIELD_DEF(uint32_t, s2CvRatio);
TILING_DATA_FIELD_DEF(uint32_t, s2Inner);
TILING_DATA_FIELD_DEF(uint32_t, s2Tail);
TILING_DATA_FIELD_DEF(uint32_t, baseMN);
// Names suggest the softmax-grad (sfmgd) loop split - TODO confirm.
TILING_DATA_FIELD_DEF(uint32_t, sfmgdOuter);
TILING_DATA_FIELD_DEF(uint32_t, sfmgdFactor);
TILING_DATA_FIELD_DEF(uint32_t, sfmgdTail);
TILING_DATA_FIELD_DEF(uint32_t, blockOuter);
TILING_DATA_FIELD_DEF(int64_t, bandIdx);
END_TILING_DATA_DEF;
// Original note (translated): "Fixed form - must not be line-wrapped, otherwise it fails."
// NOTE(review): the invocation below IS wrapped across two lines, contradicting that note. If some
// text-based tool really requires single-line registrations, this one may need joining - confirm.
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGradS1S2BNGS1S2SplitCoreParamsOp,
                           FlashAttentionScoreGradS1S2BNGS1S2SplitCoreParams)

// Per-core block range lists: two fixed-capacity arrays of 50 int64_t start/end values each.
// This type is embedded as the s1s2BNGS1S2BlockNumList member of FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2.
// NOTE(review): the capacity (50) is hard-coded here. The tiling code filling these arrays presumably
// bounds the number of entries (e.g. by core count) to at most 50 - confirm, since nothing here enforces it.
BEGIN_TILING_DATA_DEF(BlockNumListParams)
TILING_DATA_FIELD_DEF_ARR(int64_t, 50, blockStarts);
TILING_DATA_FIELD_DEF_ARR(int64_t, 50, blockEnds);
END_TILING_DATA_DEF;
// Registers the class under the name BlockNumListParamsOp.
REGISTER_TILING_DATA_CLASS(BlockNumListParamsOp, BlockNumListParams)

// Top-level tiling data for the S1S2_BNGS1S2 FlashAttentionScoreGrad template. It aggregates:
// - the base and split-core parameter structs defined in this file;
// - three TCubeTiling matmul tilings (mm1..mm3);
// - two SoftMaxTiling tilings (softmax and softmax-grad);
// - the per-core block lists;
// - the pre/post-processing parameters.
// The struct is registered below under many FlashAttentionScoreGrad_<tiling key> names.
// NOTE(review): TCubeTiling/SoftMaxTiling are presumably declared via tiling/tiling_api.h. PreParams and
// PostParams are not declared in this file or by an include visible here, so a header defining them
// must be included before this one - confirm the include order.
// NOTE(review): member order presumably defines the serialized layout - do not reorder.
BEGIN_TILING_DATA_DEF(FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
TILING_DATA_FIELD_DEF_STRUCT(FlashAttentionScoreGradS1S2BNGS1S2BaseParams, s1s2BNGS1S2BaseParams);
TILING_DATA_FIELD_DEF_STRUCT(FlashAttentionScoreGradS1S2BNGS1S2SplitCoreParams, s1s2BNGS1S2SplitCoreParams);
TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm1TilingData);
TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm2TilingData);
TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm3TilingData);
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData);
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxGradTilingData);
TILING_DATA_FIELD_DEF_STRUCT(BlockNumListParams, s1s2BNGS1S2BlockNumList);
TILING_DATA_FIELD_DEF_STRUCT(PreParams, preTilingData);
TILING_DATA_FIELD_DEF_STRUCT(PostParams, postTilingData);
END_TILING_DATA_DEF;

// TND  1000000xxxxxxxx3x434(5)
// Tiling-key registrations (original label "TND"). Each REGISTER_TILING_DATA_CLASS binds the name
// FlashAttentionScoreGrad_<20-digit key> to the FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2 layout.
// Keys follow the pattern above: 'x' marks a variable digit. The "(5)" is kept from the original
// note; every key registered in this group ends in 4.
// Observed in this group (1-based digit positions):
// - digit 16 is always 3;
// - digit 17 ranges over {1,2,3} when digit 11 is 0, but only over {2,3} when digit 11 is 1.
// NOTE(review): confirm that the missing ...31434 variants with digit 11 = 1 are intentionally unsupported.
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000000032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000000033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000000031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000111033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000111032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000111031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000011033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000011032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000011031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000101033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000101032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000101031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000001033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000001032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000001031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000110033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000110032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000110031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000010033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000010032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000010031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000100033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000100032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000100031434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001000032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001000033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001111033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001111032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001011033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001011032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001101033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001101032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001001033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001001032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001110033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001110032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001010033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001010032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001100033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000001100032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
// mm345 Nz output
// Variant keys with 1-based digit 10 set to 1. Per the original label, this presumably selects the
// path whose mm3/mm4/mm5 results are written in NZ format - TODO confirm against the kernel dispatch.
// All map to the same FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2 layout.
// Observed: digit 17 is only 2 or 3 here (no ...31434 keys).
// NOTE(review): confirm the fp32-like ...1434 variants are intentionally unsupported for this path.
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010000032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010000033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010111033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010111032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010011033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010011032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010101033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010101032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010001033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010001032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010110033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010110032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010010033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010010032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010100033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000010100032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011000032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011000033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011111033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011111032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011011033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011011032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011101033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011101032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011001033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011001032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011110033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011110032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011010033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011010032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011100033434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000011100032434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
// fp32
// Tiling keys for the fp32 variants (label from the original note), all mapped to the
// FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2 layout.
// Observed (1-based digit positions):
// - digits 12-14 vary over all eight 0/1 combinations;
// - digit 16 ranges over {0,1,2};
// - digit 17 is always 1;
// - digits 10-11 are always 0 (no NZ or digit-11 = 1 variants).
// NOTE(review): the meaning of digit 16 here is not defined in this file - confirm against the tiling-key builder.
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000000001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000111001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000011001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000101001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000001001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000110001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000010001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000100001434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000000011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000111011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000011011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000101011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000001011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000110011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000010011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000100011434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000000021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000111021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000011021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000101021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000001021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000110021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000010021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)
REGISTER_TILING_DATA_CLASS(FlashAttentionScoreGrad_10000000000100021434, FlashAttentionScoreGradTilingDataS1s2Bn2gs1s2)

} // namespace optiling
